"""
Phase 7: Reviewer Response — Clean Index, Re-estimation, and Robustness
=========================================================================
Addresses reviewer concerns:
  1. Index contamination (shock filter, no splicing)
  2. Rate measurement heterogeneity (homogeneous 10y yield appendix)
  3. Structural break decomposition (balanced panel, splits, global factor)
  4. Conditional projections (pre-GFC vs post-GFC vs full sample)
  5. OADR threshold reframing on clean rolling index

Parts 2-4 use japan_index_2c_rolling_clean as the primary DV; Part 5 (rate appendix)
uses japan_index_3c_homog and the raw yield series instead.

Input:  japanification/data/processed/japan_panel_indexed.csv
        multilateral/followup/data/processed/full_panel.csv (for projections + GDP weights)
Output: japanification/output/tables/phase7_*.md (5 files, plus companion phase7_*.csv result tables)
        japanification/data/processed/japan_panel_clean.csv (updated panel)
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Absolute root of the project checkout; every other path below derives from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
JAPAN_DIR = PROJECT_DIR / "japanification"
MULTILATERAL_DIR = PROJECT_DIR / "multilateral"
# Directory the indexed panel is read from and the cleaned panel is written to.
PROCESSED_DIR = JAPAN_DIR / "data" / "processed"
# Destination directory for the phase7_* markdown and CSV tables.
TABLE_DIR = JAPAN_DIR / "output" / "tables"

# Put the multilateral project's src/ directory first on the import path so the
# bare `model` import below resolves (presumably to multilateral/src/model.py).
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Fit a fresh PanelGLS, print its summary, and return it with its table.

    Returns a ``(model, result_df)`` pair where ``result_df`` is the output of
    ``model.to_dataframe`` with an extra ``'model'`` column set to *label*.
    """
    estimator = PanelGLS()
    estimator.fit(y, X, entity_ids, time_ids)

    # Console report: the label on its own line, then the estimator's summary.
    print("\n  {}".format(label))
    estimator.summary(feature_names=feature_names)

    # Tag every row so tables from several fits can be concatenated later.
    table = estimator.to_dataframe(feature_names=feature_names)
    table['model'] = label
    return estimator, table


def sig_stars(p):
    """Return the significance marker for p-value *p*.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1, otherwise
    the empty string (this includes NaN, which compares false everywhere).
    """
    # Cutoffs are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt_coef(coef, p):
    """Render *coef* with three decimals followed by its significance stars."""
    stars = sig_stars(p)
    return "{:.3f}{}".format(coef, stars)


def fmt_coef_se(coef, se, p):
    """Render as ``coef*** (se)``, both numbers with three decimals."""
    stars = sig_stars(p)
    return "{:.3f}{} ({:.3f})".format(coef, stars, se)


# =====================================================================
# PART 1: Clean Index Construction
# =====================================================================
def _rolling_mean_by_country(frame, column, window=5, min_periods=3):
    """Return the trailing rolling mean of *column*, computed within each iso3."""
    # NOTE(review): the window counts rows, not calendar years, so this is a
    # "5-year" average only if every country has one row per consecutive year
    # and *frame* is already sorted by year — confirm against the panel.
    return frame.groupby('iso3')[column].transform(
        lambda s: s.rolling(window, min_periods=min_periods).mean()
    )


def part1_clean_index(df):
    """Build the shock-filtered, winsorized and smoothed index columns.

    Steps, in order:
      1. Flag country-years with ``rgdp_growth`` below -10 as shocks.
      2. Blank shocks out of ``japan_index_2c`` and clip the rest to the
         1st/99th percentiles of the non-shock observations.
      3. Take a 5-row rolling mean (min 3 non-missing values) per country of
         the clean index and of the shock-filtered ``z_growth``/``z_inflation``.
      4. Build a three-component index using only ``govt_bond_10y``,
         standardized on the non-shock bond subsample.

    A markdown summary is written to ``TABLE_DIR``.

    Args:
        df: Panel with at least ``iso3``, ``year``, ``rgdp_growth``,
            ``japan_index_2c``, ``japan_index_3c``, ``z_growth``,
            ``z_inflation`` and ``govt_bond_10y``. It is not modified.

    Returns:
        A copy of *df*, sorted by ``iso3`` and ``year``, with the new columns.
    """
    print("=" * 70)
    print("PART 1: Clean Index Construction")
    print("=" * 70)

    # Work on a copy so the caller's frame does not silently gain columns.
    df = df.copy()

    lines = ["# Phase 7 Part 1: Clean Index Construction\n"]

    # --- Shock filter: exclude country-years with extreme growth contractions ---
    shock_threshold = -10.0
    shock_mask = df['rgdp_growth'] < shock_threshold
    n_shocks = shock_mask.sum()
    shock_obs = (
        df.loc[shock_mask, ['iso3', 'year', 'rgdp_growth']]
        .sort_values('rgdp_growth')
    )

    lines.append("## Shock Filter\n")
    lines.append(f"- Threshold: rgdp_growth < {shock_threshold}%")
    lines.append(f"- Observations excluded: {n_shocks}")
    lines.append(f"- Countries affected: {df.loc[shock_mask, 'iso3'].nunique()}\n")
    lines.append("Top excluded observations:\n")
    lines.append("| Country | Year | Growth |")
    lines.append("|---------|------|--------|")
    for rec in shock_obs.head(20).itertuples(index=False):
        lines.append(f"| {rec.iso3} | {int(rec.year)} | {rec.rgdp_growth:.1f}% |")

    # Mark shocks (rows are kept; they are only excluded from the clean index)
    df['shock_excluded'] = shock_mask.astype(int)

    # --- Winsorize 2c index at p1/p99 (bounds estimated on non-shock obs) ---
    clean_mask = ~shock_mask & df['japan_index_2c'].notna()
    p1 = df.loc[clean_mask, 'japan_index_2c'].quantile(0.01)
    p99 = df.loc[clean_mask, 'japan_index_2c'].quantile(0.99)

    df['japan_index_2c_clean'] = (
        df['japan_index_2c'].mask(shock_mask).clip(lower=p1, upper=p99)
    )

    lines.append("\n## Winsorization\n")
    lines.append(f"- p1 = {p1:.3f}, p99 = {p99:.3f}")
    lines.append(f"- Clean 2c obs: {df['japan_index_2c_clean'].notna().sum():,}")

    # --- Rolling 5yr MA of clean index ---
    df = df.sort_values(['iso3', 'year'])
    # Keep the mask in the same row order as the sorted frame so every later
    # boolean operation is aligned positionally as well as by label.
    shock_mask = shock_mask.reindex(df.index)
    df['japan_index_2c_rolling_clean'] = _rolling_mean_by_country(
        df, 'japan_index_2c_clean'
    )

    n_rolling = df['japan_index_2c_rolling_clean'].notna().sum()
    lines.append("\n## Rolling Clean Index (5yr MA, min_periods=3)\n")
    lines.append(f"- Observations: {n_rolling:,}")
    lines.append(f"- Countries: {df.loc[df['japan_index_2c_rolling_clean'].notna(), 'iso3'].nunique()}")
    lines.append(f"- Mean: {df['japan_index_2c_rolling_clean'].mean():.4f}")
    lines.append(f"- SD: {df['japan_index_2c_rolling_clean'].std():.4f}")

    # --- Rolling component series (shock years blanked before smoothing) ---
    for comp in ('z_growth', 'z_inflation'):
        df[f'{comp}_clean'] = df[comp].mask(shock_mask)
        df[f'{comp}_rolling'] = _rolling_mean_by_country(df, f'{comp}_clean')

    lines.append("\n## Rolling Component Series\n")
    for comp in ['z_growth_rolling', 'z_inflation_rolling']:
        n = df[comp].notna().sum()
        lines.append(f"- {comp}: {n:,} obs, mean={df[comp].mean():.4f}, SD={df[comp].std():.4f}")

    # --- Homogeneous rate index (10y govt bond only) ---
    bond_mask = df['govt_bond_10y'].notna() & ~shock_mask
    n_bond = bond_mask.sum()
    n_bond_countries = df.loc[bond_mask, 'iso3'].nunique()

    # Standardize the 10y yield on the bond subsample only
    r_mean = df.loc[bond_mask, 'govt_bond_10y'].mean()
    r_std = df.loc[bond_mask, 'govt_bond_10y'].std()
    df['z_rate_10y'] = np.nan
    df.loc[bond_mask, 'z_rate_10y'] = (df.loc[bond_mask, 'govt_bond_10y'] - r_mean) / r_std

    # 3c index: negative mean of the three z-scores, bond subsample only
    df['japan_index_3c_homog'] = np.nan
    df.loc[bond_mask, 'japan_index_3c_homog'] = -(
        df.loc[bond_mask, 'z_growth'] +
        df.loc[bond_mask, 'z_inflation'] +
        df.loc[bond_mask, 'z_rate_10y']
    ) / 3.0

    lines.append("\n## Homogeneous Rate Index (10y govt bond only)\n")
    lines.append(f"- Bond yield obs: {n_bond:,}")
    lines.append(f"- Countries: {n_bond_countries}")
    lines.append(f"- 3c homog obs: {df['japan_index_3c_homog'].notna().sum():,}")
    lines.append(f"- Mean: {df['japan_index_3c_homog'].mean():.4f}")
    lines.append(f"- SD: {df['japan_index_3c_homog'].std():.4f}")

    # --- Comparison: clean vs original ---
    lines.append("\n## Index Comparison\n")
    lines.append("| Index | N | Mean | SD | Min | Max |")
    lines.append("|-------|---|------|----|----|-----|")
    for idx_name, col in [
        ('Original 2c', 'japan_index_2c'),
        ('Clean 2c', 'japan_index_2c_clean'),
        ('Rolling clean 2c', 'japan_index_2c_rolling_clean'),
        ('Original 3c', 'japan_index_3c'),
        ('Homog 3c (10y)', 'japan_index_3c_homog'),
    ]:
        s = df[col].dropna()
        lines.append(f"| {idx_name} | {len(s):,} | {s.mean():.3f} | {s.std():.3f} | {s.min():.3f} | {s.max():.3f} |")

    # Save (create the output directory on a fresh checkout)
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    md_path = TABLE_DIR / "phase7_clean_index_summary.md"
    md_path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {md_path}")

    return df


# =====================================================================
# PART 2: Re-estimation on Clean Index
# =====================================================================
def _oadr_spline_terms(df, knot):
    """Add the two piecewise-linear ``old_dep`` terms for *knot*; return their names."""
    # round(), not int(): knot * 100 is a float product and int() would
    # truncate a value such as 28.999999999999996 to the wrong suffix.
    tag = round(knot * 100)
    below = f'oadr_below_{tag}'
    above = f'oadr_above_{tag}'
    df[below] = df['old_dep'].clip(upper=knot)
    df[above] = (df['old_dep'] - knot).clip(lower=0)
    return [below, above]


def _oadr_spline_rows(df, dep, knots, controls, min_obs=200, echo=False):
    """Fit one spline regression of *dep* per knot and return markdown rows.

    Knots whose complete-case sample has fewer than *min_obs* rows are skipped.
    With ``echo=True`` a one-line summary per knot is also printed.
    """
    rows = []
    for knot in knots:
        spline_vars = _oadr_spline_terms(df, knot) + controls
        sample = df.dropna(subset=[dep] + spline_vars)
        if len(sample) < min_obs:
            continue
        model = PanelGLS()
        model.fit(sample[dep].values, sample[spline_vars].values,
                  sample['iso3'].values, sample['year'].values)
        rows.append(
            f"| {knot:.0%} | {fmt_coef(model.beta[0], model.pvalues[0])} | {model.pvalues[0]:.3f} | "
            f"{fmt_coef(model.beta[1], model.pvalues[1])} | {model.pvalues[1]:.3f} | {model.r_squared:.4f} |"
        )
        if echo:
            print(f"  Spline knot={knot:.0%}: below={model.beta[0]:.3f} (p={model.pvalues[0]:.3f}), "
                  f"above={model.beta[1]:.3f} (p={model.pvalues[1]:.3f})")
    return rows


def part2_reestimation(df):
    """Re-estimate the baseline specifications on the clean rolling index.

    Sections written to ``phase7_baseline_clean.md``:
      2.1  baseline regression of the clean rolling 2c index on Z_1..Z_3 + controls
      2.2  the same regressors on the rolling growth / inflation components
      2.3  ``old_dep`` splines (knots 15/20/25/30%) on the clean rolling index
      2.3b the same splines on the rolling growth component
      2.4  quadratic in ``life_expectancy`` (only if that column exists)

    Coefficient tables of the models fitted through ``fit_and_report`` are
    also saved to ``phase7_baseline_clean_results.csv``.

    Args:
        df: Panel as returned by ``part1_clean_index``. It is not modified.

    Returns:
        ``(df, results_df)``: a copy of the panel with the spline (and, if
        applicable, ``le_sq``) columns added, and the stacked result tables.
    """
    print("\n" + "=" * 70)
    print("PART 2: Re-estimation on Clean Rolling Index")
    print("=" * 70)

    # Helper columns are added below; keep them off the caller's frame.
    df = df.copy()

    lines = ["# Phase 7 Part 2: Baseline Re-estimation on Clean Index\n"]
    dep_var = 'japan_index_2c_rolling_clean'
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    all_vars = demo_vars + controls
    knots = [0.15, 0.20, 0.25, 0.30]

    all_results = []

    # --- 2.1 Baseline on clean rolling index ---
    est = df.dropna(subset=[dep_var] + all_vars)
    print(f"\n  Clean rolling sample: {len(est):,} obs, {est['iso3'].nunique()} countries")

    m_base, r_base = fit_and_report(
        est[dep_var].values, est[all_vars].values,
        est['iso3'].values, est['year'].values,
        all_vars, "Baseline: Clean Rolling 2c Index"
    )
    all_results.append(r_base)

    lines.append("## 2.1 Baseline: Clean Rolling 2c Index\n")
    lines.append(f"- N = {m_base.n_obs:,}, N_countries = {m_base.n_countries}")
    lines.append(f"- R² = {m_base.r_squared:.4f}\n")
    lines.append("| Variable | Coefficient | SE | p-value |")
    lines.append("|----------|-------------|-------|---------|")
    for i, v in enumerate(all_vars):
        lines.append(f"| {v} | {fmt_coef(m_base.beta[i], m_base.pvalues[i])} | {m_base.se[i]:.3f} | {m_base.pvalues[i]:.4f} |")

    # --- 2.2 Component regressions on rolling outcomes ---
    lines.append("\n## 2.2 Component Regressions (Rolling)\n")
    lines.append("| Component | Z₁ | p | Z₂ | p | Z₃ | p | R² | N |")
    lines.append("|-----------|-----|---|-----|---|-----|---|-----|---|")

    comp_deps = {
        'Growth (rolling)': 'z_growth_rolling',
        'Inflation (rolling)': 'z_inflation_rolling',
    }

    for label, dep in comp_deps.items():
        comp_est = df.dropna(subset=[dep] + all_vars)
        if len(comp_est) < 100:
            continue
        mc, rc = fit_and_report(
            comp_est[dep].values, comp_est[all_vars].values,
            comp_est['iso3'].values, comp_est['year'].values,
            all_vars, f"Component: {label}"
        )
        all_results.append(rc)
        lines.append(
            f"| {label} | {fmt_coef(mc.beta[0], mc.pvalues[0])} | {mc.pvalues[0]:.3f} | "
            f"{fmt_coef(mc.beta[1], mc.pvalues[1])} | {mc.pvalues[1]:.3f} | "
            f"{fmt_coef(mc.beta[2], mc.pvalues[2])} | {mc.pvalues[2]:.3f} | "
            f"{mc.r_squared:.3f} | {mc.n_obs:,} |"
        )

    # --- 2.3 OADR spline on clean rolling index ---
    lines.append("\n## 2.3 OADR Spline on Clean Rolling Index\n")
    lines.append("| Knot | Below | p | Above | p | R² |")
    lines.append("|------|-------|---|-------|---|-----|")
    lines.extend(_oadr_spline_rows(df, dep_var, knots, controls, echo=True))

    # --- 2.3b OADR spline on growth component directly ---
    lines.append("\n## 2.3b OADR Spline on Growth Component (Rolling)\n")
    lines.append("| Knot | Below | p | Above | p | R² |")
    lines.append("|------|-------|---|-------|---|-----|")
    lines.extend(_oadr_spline_rows(df, 'z_growth_rolling', knots, controls))

    # --- 2.4 Life expectancy quadratic ---
    if 'life_expectancy' in df.columns:
        df['le_sq'] = df['life_expectancy'] ** 2
        le_vars = ['life_expectancy', 'le_sq'] + controls
        est_le = df.dropna(subset=[dep_var] + le_vars)
        if len(est_le) >= 200:
            m_le, r_le = fit_and_report(
                est_le[dep_var].values, est_le[le_vars].values,
                est_le['iso3'].values, est_le['year'].values,
                le_vars, "LE Quadratic on Clean Rolling Index"
            )
            all_results.append(r_le)

            le_idx = le_vars.index('life_expectancy')
            le_sq_idx = le_vars.index('le_sq')
            # Vertex of a + b*LE + c*LE^2 is at -b / (2c); undefined when c == 0.
            if m_le.beta[le_sq_idx] != 0:
                turning_point = -m_le.beta[le_idx] / (2 * m_le.beta[le_sq_idx])
            else:
                turning_point = np.nan

            lines.append("\n## 2.4 Life Expectancy Quadratic\n")
            lines.append(f"- LE coef: {m_le.beta[le_idx]:.4f} (p={m_le.pvalues[le_idx]:.4f})")
            lines.append(f"- LE² coef: {m_le.beta[le_sq_idx]:.6f} (p={m_le.pvalues[le_sq_idx]:.4f})")
            lines.append(f"- Turning point: {turning_point:.1f} years")
            lines.append(f"- R² = {m_le.r_squared:.4f}, N = {m_le.n_obs:,}")

    # Save (create the output directory on a fresh checkout)
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    results_df = pd.concat(all_results, ignore_index=True)
    results_df.to_csv(TABLE_DIR / "phase7_baseline_clean_results.csv", index=False)

    md_path = TABLE_DIR / "phase7_baseline_clean.md"
    md_path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {md_path}")

    return df, results_df


# =====================================================================
# PART 3: Structural Break Decomposition
# =====================================================================
def part3_structural_break(df):
    """Decompose the pre-/post-GFC break in the clean rolling index regressions.

    Sections written to ``phase7_structural_break_decomp.md``:
      3.1  countries with at least 5 complete observations in both windows
      3.2  OECD vs non-OECD, by period
      3.3  country income quartile (median ``gdp_pc_ppp``), by period
      3.4  baseline plus a GDP-weighted world growth regressor

    Every period split uses the same closed windows, 1990-2007 and 2009-2024
    (2008 is in neither), so the samples in 3.1 match the windows that define
    the balanced country set.

    Args:
        df: Panel as returned by ``part2_reestimation``. It is not modified.

    Returns:
        A copy of the panel with ``is_oecd``, ``income_q`` and
        ``world_growth`` columns added (the merge resets the row index).
    """
    print("\n" + "=" * 70)
    print("PART 3: Structural Break Decomposition")
    print("=" * 70)

    # Helper columns are added below; keep them off the caller's frame.
    df = df.copy()

    lines = ["# Phase 7 Part 3: Structural Break Decomposition\n"]
    dep_var = 'japan_index_2c_rolling_clean'
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    base_vars = demo_vars + controls
    required = [dep_var] + base_vars

    pre_window = (1990, 2007)
    post_window = (2009, 2024)

    def period_masks(frame, suffix=''):
        """Labelled boolean masks for the two windows (bounds inclusive)."""
        return [
            (f'Pre-GFC{suffix}', frame['year'].between(*pre_window)),
            (f'Post-GFC{suffix}', frame['year'].between(*post_window)),
        ]

    all_results = []

    # Load full_panel for GDP weights
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    fp_gdp = fp[['iso3', 'year', 'ngdp_usd', 'rgdp_growth']].dropna(subset=['ngdp_usd', 'rgdp_growth'])

    # --- 3.1 Balanced panel ---
    lines.append("## 3.1 Balanced Panel (Countries Present in Both Periods)\n")

    def countries_with_coverage(mask, min_obs=5):
        """Countries with at least *min_obs* complete rows inside *mask*."""
        counts = df[mask].dropna(subset=required).groupby('iso3').size()
        return set(counts[counts >= min_obs].index)

    (_, pre_mask), (_, post_mask) = period_masks(df)
    pre_countries = countries_with_coverage(pre_mask)
    post_countries = countries_with_coverage(post_mask)
    balanced = pre_countries & post_countries
    lines.append(f"- Pre-GFC countries (≥5 obs): {len(pre_countries)}")
    lines.append(f"- Post-GFC countries (≥5 obs): {len(post_countries)}")
    lines.append(f"- Balanced set: {len(balanced)} countries\n")

    lines.append("| Period | Z₁ | p | Z₂ | p | Z₃ | p | R² | N |")
    lines.append("|--------|-----|---|-----|---|-----|---|-----|---|")

    in_balanced = df['iso3'].isin(balanced)
    for label, yr_mask in period_masks(df, suffix=' (balanced)'):
        sub = df[yr_mask & in_balanced].dropna(subset=required)
        if len(sub) >= 100:
            m, r = fit_and_report(
                sub[dep_var].values, sub[base_vars].values,
                sub['iso3'].values, sub['year'].values,
                base_vars, label
            )
            all_results.append(r)
            lines.append(
                f"| {label} | {fmt_coef(m.beta[0], m.pvalues[0])} | {m.pvalues[0]:.3f} | "
                f"{fmt_coef(m.beta[1], m.pvalues[1])} | {m.pvalues[1]:.3f} | "
                f"{fmt_coef(m.beta[2], m.pvalues[2])} | {m.pvalues[2]:.3f} | "
                f"{m.r_squared:.3f} | {m.n_obs:,} |"
            )

    # --- 3.2 OECD vs non-OECD ---
    lines.append("\n## 3.2 OECD vs Non-OECD Pre/Post GFC\n")

    # OECD list (approximate); applied as a fixed set to every year
    OECD = [
        'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK',
        'EST', 'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR',
        'ITA', 'JPN', 'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL',
        'NOR', 'POL', 'PRT', 'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR',
        'GBR', 'USA',
    ]
    df['is_oecd'] = df['iso3'].isin(OECD).astype(int)

    lines.append("| Group | Period | Z₁ | p | R² | N | N_c |")
    lines.append("|-------|--------|-----|---|----|---|-----|")

    for group_label, group_mask in [('OECD', df['is_oecd'] == 1),
                                     ('Non-OECD', df['is_oecd'] == 0)]:
        for period_label, yr_mask in period_masks(df):
            sub = df[group_mask & yr_mask].dropna(subset=required)
            if len(sub) >= 80:
                m, r = fit_and_report(
                    sub[dep_var].values, sub[base_vars].values,
                    sub['iso3'].values, sub['year'].values,
                    base_vars, f"{group_label}: {period_label}"
                )
                all_results.append(r)
                lines.append(
                    f"| {group_label} | {period_label} | {fmt_coef(m.beta[0], m.pvalues[0])} | "
                    f"{m.pvalues[0]:.3f} | {m.r_squared:.3f} | {m.n_obs:,} | {m.n_countries} |"
                )

    # --- 3.3 Income quartile splits ---
    lines.append("\n## 3.3 Income Quartile × Period\n")

    # Use country-median GDP/cap to assign quartiles (one label per country)
    country_gdp = df.groupby('iso3')['gdp_pc_ppp'].median()
    quartile_map = pd.qcut(country_gdp, 4, labels=['Q1', 'Q2', 'Q3', 'Q4'])
    df['income_q'] = df['iso3'].map(quartile_map)

    lines.append("| Quartile | Period | Z₁ | p | R² | N |")
    lines.append("|----------|--------|-----|---|----|---|")

    for q in ['Q1', 'Q2', 'Q3', 'Q4']:
        for period_label, yr_mask in period_masks(df):
            sub = df[(df['income_q'] == q) & yr_mask].dropna(subset=required)
            if len(sub) >= 80:
                m, r = fit_and_report(
                    sub[dep_var].values, sub[base_vars].values,
                    sub['iso3'].values, sub['year'].values,
                    base_vars, f"{q}: {period_label}"
                )
                all_results.append(r)
                lines.append(
                    f"| {q} | {period_label} | {fmt_coef(m.beta[0], m.pvalues[0])} | "
                    f"{m.pvalues[0]:.3f} | {m.r_squared:.3f} | {m.n_obs:,} |"
                )
            else:
                lines.append(f"| {q} | {period_label} | -- | -- | -- | {len(sub)} |")

    # --- 3.4 Global factor control ---
    lines.append("\n## 3.4 Global Factor Control\n")

    # Compute GDP-weighted world growth by year (weights floored at 1)
    world_g = (fp_gdp.groupby('year')
               .apply(lambda g: np.average(g['rgdp_growth'], weights=g['ngdp_usd'].clip(lower=1)),
                      include_groups=False)
               .reset_index())
    world_g.columns = ['year', 'world_growth']
    # Drop a pre-existing world_growth first so the merge cannot create
    # world_growth_x / world_growth_y and break the column lookup below.
    df = df.drop(columns=['world_growth'], errors='ignore').merge(world_g, on='year', how='left')

    world_vars = base_vars + ['world_growth']
    wg_idx = world_vars.index('world_growth')

    lines.append("| Period | Z₁ | p | world_growth | p_wg | R² | N |")
    lines.append("|--------|-----|---|-------------|------|-----|---|")

    # The merge produced a new row index, so masks are rebuilt on the new frame.
    for label, yr_mask in [('Full', pd.Series(True, index=df.index))] + period_masks(df):
        sub = df[yr_mask].dropna(subset=[dep_var] + world_vars)
        if len(sub) >= 100:
            m, r = fit_and_report(
                sub[dep_var].values, sub[world_vars].values,
                sub['iso3'].values, sub['year'].values,
                world_vars, f"Global factor: {label}"
            )
            all_results.append(r)
            lines.append(
                f"| {label} | {fmt_coef(m.beta[0], m.pvalues[0])} | {m.pvalues[0]:.3f} | "
                f"{fmt_coef(m.beta[wg_idx], m.pvalues[wg_idx])} | {m.pvalues[wg_idx]:.3f} | "
                f"{m.r_squared:.3f} | {m.n_obs:,} |"
            )

    # Save (create the output directory on a fresh checkout)
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(TABLE_DIR / "phase7_structural_break_results.csv", index=False)

    md_path = TABLE_DIR / "phase7_structural_break_decomp.md"
    md_path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {md_path}")

    return df


# =====================================================================
# PART 4: Conditional Projections
# =====================================================================
def part4_projections(df):
    """Project the index for focus countries under three coefficient sets.

    Coefficients are estimated on the full sample, on 1990-2007 and on
    2009-2024. For each focus country and projection year the fitted value is
    ``constant + sum(beta_Z * Z from full_panel) + sum(beta_ctrl * last
    observed control)``, and is compared with Japan's mean clean rolling index
    over 1999-2001 (falling back to its 2000 ``japan_index_2c``, then 0.226).

    In the markdown tables ``--`` marks a cell with no projection (scenario
    not estimated, or missing inputs for that country); ``Never`` is reserved
    for countries that were projected and stay below the threshold.

    Args:
        df: Panel as returned by ``part3_structural_break``.

    Returns:
        *df*, unchanged.
    """
    print("\n" + "=" * 70)
    print("PART 4: Conditional Projections")
    print("=" * 70)

    lines = ["# Phase 7 Part 4: Conditional Projections\n"]
    dep_var = 'japan_index_2c_rolling_clean'
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    all_vars = demo_vars + controls

    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")

    # Year masks are bounded on both sides so each sample matches its label.
    scenario_specs = [
        ('Full sample', pd.Series(True, index=df.index)),
        ('Pre-GFC (1990-2007)', df['year'].between(1990, 2007)),
        ('Post-GFC (2009-2024)', df['year'].between(2009, 2024)),
    ]
    scenario_names = [name for name, _ in scenario_specs]

    # Estimate three sets of coefficients
    scenarios = {}
    for label, yr_mask in scenario_specs:
        sub = df[yr_mask].dropna(subset=[dep_var] + all_vars)
        if len(sub) < 100:
            print(f"  Skipping {label}: {len(sub)} obs")
            continue
        m, _ = fit_and_report(
            sub[dep_var].values, sub[all_vars].values,
            sub['iso3'].values, sub['year'].values,
            all_vars, f"Projection base: {label}"
        )
        scenarios[label] = m

    if not scenarios:
        print("  No scenarios estimated — skipping projections")
        return df

    # Japan threshold (rolling clean index at ~2000)
    is_jpn = df['iso3'] == 'JPN'
    jpn_rolling = df.loc[is_jpn & df['year'].between(1999, 2001), dep_var].dropna()
    if len(jpn_rolling) > 0:
        japan_threshold = jpn_rolling.mean()
    else:
        # Fallback to the 2c value; dropna so a missing value cannot become a
        # NaN threshold that every comparison silently fails against.
        jpn_2c = df.loc[is_jpn & (df['year'] == 2000), 'japan_index_2c'].dropna()
        japan_threshold = jpn_2c.values[0] if len(jpn_2c) > 0 else 0.226
    lines.append(f"Japan threshold (rolling clean ~2000): {japan_threshold:.3f}\n")

    # Focus countries
    focus_countries = [
        'JPN', 'DEU', 'ITA', 'GRC', 'KOR', 'CHN', 'THA', 'ESP',
        'USA', 'GBR', 'FRA', 'CAN', 'AUS', 'BRA', 'IND', 'IDN',
        'MEX', 'ZAF', 'NGA', 'SAU', 'TUR', 'POL', 'RUS', 'SGP',
    ]

    proj_years = [2020, 2030, 2040, 2050, 2060]

    # Last observed controls (latest year with all controls present)
    last_controls = (df.dropna(subset=controls)
                     .sort_values('year')
                     .groupby('iso3')[controls]
                     .last())

    # Slice the projection panel once per country instead of once per scenario.
    country_panels = {iso3: fp[fp['iso3'] == iso3] for iso3 in focus_countries}

    # Build projections for each scenario
    all_proj_rows = []
    for scenario_name, model in scenarios.items():
        coefs = dict(zip(all_vars, model.beta))
        const = model.constant

        for iso3 in focus_countries:
            cdata = country_panels[iso3]
            if iso3 not in last_controls.index or len(cdata) == 0:
                continue
            ctrl = last_controls.loc[iso3]
            control_effect = sum(coefs[cv] * ctrl[cv] for cv in controls)

            for year in proj_years:
                yr = cdata[cdata['year'] == year]
                if len(yr) == 0:
                    continue

                demo_effect = sum(coefs[zv] * yr[zv].values[0] for zv in demo_vars)
                total = demo_effect + control_effect + const
                if pd.isna(total):
                    # Missing demographic inputs: no projection, rather than a
                    # NaN that would later be read as "below threshold".
                    continue

                all_proj_rows.append({
                    'scenario': scenario_name,
                    'iso3': iso3,
                    'year': year,
                    'japan_index_projected': total,
                    'exceeds_threshold': bool(total > japan_threshold),
                })

    # Explicit columns keep the lookups below valid even with zero rows.
    proj_df = pd.DataFrame(
        all_proj_rows,
        columns=['scenario', 'iso3', 'year', 'japan_index_projected', 'exceeds_threshold'],
    )

    # Crossing years by scenario
    lines.append("## Crossing Years by Scenario\n")
    lines.append("| Country | Full Sample | Pre-GFC | Post-GFC |")
    lines.append("|---------|-------------|---------|----------|")

    for iso3 in focus_countries:
        cells = []
        for scenario_name in scenario_names:
            sub = proj_df[(proj_df['iso3'] == iso3) & (proj_df['scenario'] == scenario_name)]
            if len(sub) == 0:
                cells.append("--")
                continue
            above = sub[sub['japan_index_projected'] > japan_threshold]
            if len(above) == 0:
                cells.append("Never")
                continue
            crossing = int(above['year'].min())
            cells.append("Already" if crossing <= 2020 else f"~{crossing}")
        # The country is its own cell so rows line up with the 4-column header.
        lines.append(f"| {iso3} | " + " | ".join(cells) + " |")

    # Side-by-side projections for key countries
    lines.append("\n## Projected Index Values (2050)\n")
    lines.append("| Country | Full Sample | Pre-GFC | Post-GFC |")
    lines.append("|---------|-------------|---------|----------|")

    for iso3 in focus_countries:
        cells = []
        for scenario_name in scenario_names:
            sub = proj_df[(proj_df['iso3'] == iso3) &
                          (proj_df['scenario'] == scenario_name) &
                          (proj_df['year'] == 2050)]
            if len(sub) > 0:
                cells.append(f"{sub['japan_index_projected'].values[0]:.3f}")
            else:
                cells.append("--")
        lines.append(f"| {iso3} | " + " | ".join(cells) + " |")

    # Count crossings by scenario
    lines.append("\n## Summary: Countries Crossing Threshold by 2050\n")
    for scenario_name in scenarios:
        sub = proj_df[(proj_df['scenario'] == scenario_name) &
                      (proj_df['year'] <= 2050)]
        n_cross = sub.groupby('iso3')['exceeds_threshold'].any().sum()
        n_total = sub['iso3'].nunique()
        lines.append(f"- **{scenario_name}**: {n_cross}/{n_total} countries cross by 2050")

    # Save (create the output directory on a fresh checkout)
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    proj_df.to_csv(TABLE_DIR / "phase7_conditional_projections.csv", index=False)

    md_path = TABLE_DIR / "phase7_conditional_projections.md"
    md_path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {md_path}")

    return df


# =====================================================================
# PART 5: Rate Appendix
# =====================================================================
def part5_rate_appendix(df):
    """Write the rate-channel appendix based on 10-year government bond yields.

    Sections written to ``phase7_rate_appendix.md``:
      5.1  baseline regression of ``japan_index_3c_homog`` on Z_1..Z_3 + controls
      5.2  the same regression for years <= 2007 and for years >= 2009
      5.3  ``govt_bond_10y`` regressed directly on the same variables
      5.4  ``rate_japan`` regressed on the same variables, for comparison

    Coefficient tables are also saved to ``phase7_rate_appendix_results.csv``
    when at least one model was estimated.

    Args:
        df: Panel containing ``japan_index_3c_homog``, ``govt_bond_10y``,
            ``rate_japan`` and the regressors. It is only read.

    Returns:
        None.
    """
    print("\n" + "=" * 70)
    print("PART 5: Rate Appendix (Homogeneous 10y Yield)")
    print("=" * 70)

    lines = ["# Phase 7 Part 5: Rate Channel Appendix\n"]
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    base_vars = demo_vars + controls

    all_results = []

    # --- 5.1 Baseline on homogeneous 3c index ---
    dep_homog = 'japan_index_3c_homog'
    est = df.dropna(subset=[dep_homog] + base_vars)
    lines.append("## 5.1 Baseline: Homogeneous 3c Index (10y yields only)\n")
    lines.append(f"- Sample: {len(est):,} obs, {est['iso3'].nunique()} countries\n")

    if len(est) >= 100:
        m, r = fit_and_report(
            est[dep_homog].values, est[base_vars].values,
            est['iso3'].values, est['year'].values,
            base_vars, "Baseline: Homog 3c (10y only)"
        )
        all_results.append(r)

        lines.append("| Variable | Coefficient | SE | p-value |")
        lines.append("|----------|-------------|-------|---------|")
        for i, v in enumerate(base_vars):
            lines.append(f"| {v} | {fmt_coef(m.beta[i], m.pvalues[i])} | {m.se[i]:.3f} | {m.pvalues[i]:.4f} |")
        lines.append(f"\nR² = {m.r_squared:.4f}, N = {m.n_obs:,}")

    # --- 5.2 Pre/post GFC on homogeneous 3c ---
    lines.append("\n## 5.2 Pre/Post GFC on Homogeneous 3c\n")
    lines.append("| Period | Z₁ | p | Z₂ | p | Z₃ | p | R² | N |")
    lines.append("|--------|-----|---|-----|---|-----|---|-----|---|")

    # NOTE(review): these masks are open-ended (no 1990 / 2024 bound), whereas
    # the Part 4 scenario labels read "1990-2007" / "2009-2024" — confirm the
    # intended windows if the panel extends beyond that range.
    for label, yr_mask in [('Pre-GFC', df['year'] <= 2007),
                           ('Post-GFC', df['year'] >= 2009)]:
        sub = df[yr_mask].dropna(subset=[dep_homog] + base_vars)
        if len(sub) >= 80:
            m, r = fit_and_report(
                sub[dep_homog].values, sub[base_vars].values,
                sub['iso3'].values, sub['year'].values,
                base_vars, f"Homog 3c: {label}"
            )
            all_results.append(r)
            lines.append(
                f"| {label} | {fmt_coef(m.beta[0], m.pvalues[0])} | {m.pvalues[0]:.3f} | "
                f"{fmt_coef(m.beta[1], m.pvalues[1])} | {m.pvalues[1]:.3f} | "
                f"{fmt_coef(m.beta[2], m.pvalues[2])} | {m.pvalues[2]:.3f} | "
                f"{m.r_squared:.3f} | {m.n_obs:,} |"
            )

    # --- 5.3 Z → govt_bond_10y directly ---
    lines.append("\n## 5.3 Direct: Z → 10y Government Bond Yield\n")

    bond_est = df.dropna(subset=['govt_bond_10y'] + base_vars)
    if len(bond_est) >= 100:
        m_bond, r_bond = fit_and_report(
            bond_est['govt_bond_10y'].values, bond_est[base_vars].values,
            bond_est['iso3'].values, bond_est['year'].values,
            base_vars, "Z → govt_bond_10y"
        )
        all_results.append(r_bond)

        lines.append("| Variable | Coefficient | SE | p-value |")
        lines.append("|----------|-------------|-------|---------|")
        for i, v in enumerate(base_vars):
            lines.append(f"| {v} | {fmt_coef(m_bond.beta[i], m_bond.pvalues[i])} | {m_bond.se[i]:.3f} | {m_bond.pvalues[i]:.4f} |")
        lines.append(f"\nR² = {m_bond.r_squared:.4f}, N = {m_bond.n_obs:,}")

        # The demographic terms come first in base_vars; coerce to an array so
        # the element-wise comparison works whatever sequence type pvalues is.
        demo_p = np.asarray(m_bond.pvalues, dtype=float)[:len(demo_vars)]
        verdict = 'predict' if (demo_p < 0.1).any() else 'do NOT predict'
        lines.append(f"\n**Rate channel conclusion**: Demographics {verdict} "
                     f"10y bond yields on homogeneous data.")

    # --- 5.4 Z → rate_japan (original heterogeneous) for comparison ---
    lines.append("\n## 5.4 Comparison: Z → rate_japan (Heterogeneous)\n")

    rate_est = df.dropna(subset=['rate_japan'] + base_vars)
    if len(rate_est) >= 100:
        m_rate, r_rate = fit_and_report(
            rate_est['rate_japan'].values, rate_est[base_vars].values,
            rate_est['iso3'].values, rate_est['year'].values,
            base_vars, "Z → rate_japan (heterogeneous)"
        )
        all_results.append(r_rate)

        lines.append(f"- Z₁: {fmt_coef(m_rate.beta[0], m_rate.pvalues[0])} (p={m_rate.pvalues[0]:.4f})")
        lines.append(f"- Z₂: {fmt_coef(m_rate.beta[1], m_rate.pvalues[1])} (p={m_rate.pvalues[1]:.4f})")
        lines.append(f"- Z₃: {fmt_coef(m_rate.beta[2], m_rate.pvalues[2])} (p={m_rate.pvalues[2]:.4f})")
        lines.append(f"- N = {m_rate.n_obs:,}, R² = {m_rate.r_squared:.4f}")

    # Save (create the output directory on a fresh checkout)
    TABLE_DIR.mkdir(parents=True, exist_ok=True)
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(TABLE_DIR / "phase7_rate_appendix_results.csv", index=False)

    md_path = TABLE_DIR / "phase7_rate_appendix.md"
    md_path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {md_path}")


# =====================================================================
# MAIN
# =====================================================================
def main():
    """Run Parts 1-5 in order on the indexed panel and save the updated panel.

    Each part receives the frame returned by the previous one, so columns
    added earlier (clean index, spline terms, world growth, ...) are available
    downstream and end up in ``japan_panel_clean.csv``.
    """
    print("=" * 70)
    print("PHASE 7: Reviewer Response")
    print("=" * 70)

    # Create the table directory up front so a fresh checkout fails (or not)
    # before any estimation work is done, not at the first write.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(PROCESSED_DIR / "japan_panel_indexed.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # Part 1: Clean index
    df = part1_clean_index(df)

    # Part 2: Re-estimation (the stacked result table is already saved to disk)
    df, _ = part2_reestimation(df)

    # Part 3: Structural break
    df = part3_structural_break(df)

    # Part 4: Conditional projections
    df = part4_projections(df)

    # Part 5: Rate appendix
    part5_rate_appendix(df)

    # Save updated panel
    out_path = PROCESSED_DIR / "japan_panel_clean.csv"
    df.to_csv(out_path, index=False)
    print(f"\nSaved updated panel: {out_path}")
    print(f"  {len(df):,} obs, new columns: japan_index_2c_clean, japan_index_2c_rolling_clean, "
          f"z_growth_rolling, z_inflation_rolling, japan_index_3c_homog")

    print(f"\n{'=' * 70}")
    print("Phase 7 complete. Output files:")
    report_names = [
        'phase7_clean_index_summary.md',
        'phase7_baseline_clean.md',
        'phase7_structural_break_decomp.md',
        'phase7_conditional_projections.md',
        'phase7_rate_appendix.md',
    ]
    for number, report_name in enumerate(report_names, start=1):
        print(f"  {number}. {TABLE_DIR / report_name}")
    print("=" * 70)


if __name__ == "__main__":
    main()
